{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Modifying the Variational Strategy/Variational Distribution\n",
    "\n",
    "The predictive distribution for approximate GPs is given by\n",
    "\n",
    "$$\n",
    "p( f(\\mathbf x^*) ) = \\int_{\\mathbf u} p( f(\\mathbf x^*) \\mid \\mathbf u) \\: q(\\mathbf u) \\: d\\mathbf u,\n",
    "\\quad\n",
    "q(\\mathbf u) = \\mathcal N( \\mathbf m, \\mathbf S).\n",
    "$$\n",
    "\n",
    "$\\mathbf u$ represents the function values at the $m$ inducing points.\n",
    "Here, $\\mathbf m \\in \\mathbb R^m$ and $\\mathbf S \\in \\mathbb R^{m \\times m}$ are learnable parameters.\n",
    "\n",
    "If $m$ (the number of inducing points) is quite large, the number of learnable parameters in $\\mathbf S$ can be quite unwieldy.\n",
    "Furthermore, a large $m$ might make some of the computations rather slow.\n",
    "Here we show a few ways to use different [variational distributions](https://gpytorch.readthedocs.io/en/latest/variational.html#variational-distributions) and\n",
    "[variational strategies](https://gpytorch.readthedocs.io/en/latest/variational.html#variational-strategies) to reduce the number of learnable parameters and the computational cost."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experimental setup\n",
    "\n",
    "We're going to train an approximate GP on a medium-sized regression dataset, taken from the UCI repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "\n",
    "# Fetch the dataset once: skipped in the testing framework and when the file is already on disk\n",
    "if not (smoke_test or os.path.isfile('../elevators.mat')):\n",
    "    print('Downloading \\'elevators\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')\n",
    "\n",
    "\n",
    "if smoke_test:\n",
    "    # Random stand-in data, so that the testing framework never touches the real dataset\n",
    "    X, y = torch.randn(1000, 3), torch.randn(1000)\n",
    "else:\n",
    "    data = torch.Tensor(loadmat('../elevators.mat')['data'])\n",
    "    # Every column but the last is a feature; the last column is the regression target\n",
    "    X, y = data[:, :-1], data[:, -1]\n",
    "    # Rescale each feature column to the range [-1, 1]\n",
    "    X = X - X.min(dim=0).values\n",
    "    X = 2 * (X / X.max(dim=0).values) - 1\n",
    "\n",
    "\n",
    "# The first 80% of the rows are used for training, the remainder for testing\n",
    "train_n = int(floor(0.8 * len(X)))\n",
    "train_x, test_x = X[:train_n, :].contiguous(), X[train_n:, :].contiguous()\n",
    "train_y, test_y = y[:train_n].contiguous(), y[train_n:].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y = train_x.cuda(), train_y.cuda()\n",
    "    test_x, test_y = test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# Minibatches of 500 points: the training batches are reshuffled on every pass,\n",
    "# the test batches are served in dataset order\n",
    "train_dataset, test_dataset = TensorDataset(train_x, train_y), TensorDataset(test_x, test_y)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=500)\n",
    "test_loader = DataLoader(test_dataset, shuffle=False, batch_size=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Some quick training/testing code\n",
    "\n",
    "This will allow us to train/test different model classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is for running the notebook in our testing framework\n",
    "num_epochs = 1 if smoke_test else 10\n",
    "\n",
    "\n",
    "def train_and_test_approximate_gp(model_cls):\n",
    "    \"\"\"Train an approximate GP of class ``model_cls`` and print its test-set MAE.\n",
    "\n",
    "    ``model_cls`` is called with a single argument: a tensor of 128 randomly drawn\n",
    "    inducing points that has the feature dimension, dtype and device of ``train_x``.\n",
    "    The model and a fresh ``GaussianLikelihood`` are trained jointly with Adam on the\n",
    "    negative ``VariationalELBO`` for ``num_epochs`` passes over ``train_loader``.\n",
    "    Afterwards the mean absolute error between the predictive means over\n",
    "    ``test_loader`` and ``test_y`` is printed. Nothing is returned.\n",
    "\n",
    "    Reads the notebook-level names ``train_x``, ``train_y``, ``test_y``,\n",
    "    ``train_loader``, ``test_loader`` and ``num_epochs``.\n",
    "    \"\"\"\n",
    "    # A bare `import tqdm` does not load the `tqdm.notebook` submodule, so import it explicitly\n",
    "    from tqdm.notebook import tqdm as notebook_tqdm\n",
    "\n",
    "    inducing_points = torch.randn(128, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)\n",
    "    model = model_cls(inducing_points)\n",
    "    likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "\n",
    "    # Move the modules to the GPU *before* the optimizer is constructed (see the torch.optim docs)\n",
    "    if torch.cuda.is_available():\n",
    "        model = model.cuda()\n",
    "        likelihood = likelihood.cuda()\n",
    "\n",
    "    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())\n",
    "    optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.1)\n",
    "\n",
    "    # Training\n",
    "    model.train()\n",
    "    likelihood.train()\n",
    "    epochs_iter = notebook_tqdm(range(num_epochs), desc=f\"Training {model_cls.__name__}\")\n",
    "    for _ in epochs_iter:\n",
    "        # Within each epoch, we will go over each minibatch of data\n",
    "        for x_batch, y_batch in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            output = model(x_batch)\n",
    "            loss = -mll(output, y_batch)\n",
    "            epochs_iter.set_postfix(loss=loss.item())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "    # Testing\n",
    "    model.eval()\n",
    "    likelihood.eval()\n",
    "    batch_means = []\n",
    "    with torch.no_grad():\n",
    "        for x_batch, _ in test_loader:\n",
    "            batch_means.append(model(x_batch).mean.cpu())\n",
    "    # Concatenate once; batches are appended in loader order before comparing against test_y\n",
    "    means = torch.cat(batch_means)\n",
    "    error = torch.mean(torch.abs(means - test_y.cpu()))\n",
    "    print(f\"Test {model_cls.__name__} MAE: {error.item()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Standard Approach\n",
    "\n",
    "As a default, we'll use the default [VariationalStrategy](https://gpytorch.readthedocs.io/en/latest/variational.html#id1) class with a [CholeskyVariationalDistribution](https://gpytorch.readthedocs.io/en/latest/variational.html#choleskyvariationaldistribution).\n",
    "The `CholeskyVariationalDistribution` class allows $\\mathbf S$ to be any positive semidefinite matrix. This is the most general/expressive option for approximate GPs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StandardApproximateGP(gpytorch.models.ApproximateGP):\n",
    "    \"\"\"Approximate GP using a CholeskyVariationalDistribution (full S matrix).\n",
    "\n",
    "    The model has a constant mean and a scaled RBF kernel, and the inducing\n",
    "    locations are learned along with the other parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_points):\n",
    "        # One variational parameter set per inducing point (second-to-last dimension)\n",
    "        num_inducing = inducing_points.size(-2)\n",
    "        q_u = gpytorch.variational.CholeskyVariationalDistribution(num_inducing)\n",
    "        strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self, inducing_points, q_u, learn_inducing_locations=True\n",
    "        )\n",
    "        super().__init__(strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a MultivariateNormal built from the mean and kernel modules evaluated at ``x``.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ecb0000e03d4b95b5e9dd762da0c1bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Training StandardApproximateGP', max=10, style=ProgressStyle(…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test StandardApproximateGP MAE: 0.10098349303007126\n"
     ]
    }
   ],
   "source": [
    "train_and_test_approximate_gp(StandardApproximateGP)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reducing parameters\n",
    "\n",
    "### MeanFieldVariationalDistribution: a diagonal $\\mathbf S$ matrix \n",
    "\n",
    "One way to reduce the number of parameters is to restrict $\\mathbf S$ to be a diagonal matrix. This is less expressive, but the number of parameters is now linear in $m$ instead of quadratic.\n",
    "\n",
    "All we have to do is take the previous example, and change `CholeskyVariationalDistribution` (full $\\mathbf S$ matrix) to [MeanFieldVariationalDistribution](https://gpytorch.readthedocs.io/en/latest/variational.html#meanfieldvariationaldistribution) (diagonal $\\mathbf S$ matrix)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MeanFieldApproximateGP(gpytorch.models.ApproximateGP):\n",
    "    \"\"\"Approximate GP using a MeanFieldVariationalDistribution (diagonal S matrix).\n",
    "\n",
    "    Apart from the variational distribution, the model is set up like the standard one:\n",
    "    constant mean, scaled RBF kernel, learned inducing locations.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_points):\n",
    "        num_inducing = inducing_points.size(-2)\n",
    "        q_u = gpytorch.variational.MeanFieldVariationalDistribution(num_inducing)\n",
    "        strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self, inducing_points, q_u, learn_inducing_locations=True\n",
    "        )\n",
    "        super().__init__(strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a MultivariateNormal built from the mean and kernel modules evaluated at ``x``.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e95f8db268647f4b069758d9ed9dffd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Training MeanFieldApproximateGP', max=10, style=ProgressStyle…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test MeanFieldApproximateGP MAE: 0.07848489284515381\n"
     ]
    }
   ],
   "source": [
    "train_and_test_approximate_gp(MeanFieldApproximateGP)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DeltaVariationalDistribution: no $\\mathbf S$ matrix \n",
    "\n",
    "A more extreme method of reducing parameters is to get rid of $\\mathbf S$ entirely. This corresponds to learning a delta distribution ($\\mathbf u = \\mathbf m$) rather than a multivariate Normal distribution for $\\mathbf u$. In other words, this corresponds to performing MAP estimation rather than variational inference.\n",
    "\n",
    "In GPyTorch, getting rid of $\\mathbf S$ can be accomplished by using a [DeltaVariationalDistribution](https://gpytorch.readthedocs.io/en/latest/variational.html#deltavariationaldistribution)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MAPApproximateGP(gpytorch.models.ApproximateGP):\n",
    "    \"\"\"Approximate GP using a DeltaVariationalDistribution (no S matrix).\n",
    "\n",
    "    Apart from the variational distribution, the model is set up like the standard one:\n",
    "    constant mean, scaled RBF kernel, learned inducing locations.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_points):\n",
    "        num_inducing = inducing_points.size(-2)\n",
    "        q_u = gpytorch.variational.DeltaVariationalDistribution(num_inducing)\n",
    "        strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self, inducing_points, q_u, learn_inducing_locations=True\n",
    "        )\n",
    "        super().__init__(strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a MultivariateNormal built from the mean and kernel modules evaluated at ``x``.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b6cb9a96daea4abc8d092f82d38c0c40",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Training MAPApproximateGP', max=10, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test MAPApproximateGP MAE: 0.08846496045589447\n"
     ]
    }
   ],
   "source": [
    "train_and_test_approximate_gp(MAPApproximateGP)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reducing computation (through decoupled inducing points)\n",
    "\n",
    "One way to reduce the computational complexity is to use separate inducing points for the mean and covariance computations. The [Orthogonally Decoupled Variational Gaussian Processes](https://arxiv.org/abs/1809.08820) method of Salimbeni et al. (2018) uses more inducing points for the (computationally easy) mean computations and fewer inducing points for the (computationally intensive) covariance computations.\n",
    "\n",
    "In GPyTorch we implement this method in a modular way. The [OrthogonallyDecoupledVariationalStrategy](http://gpytorch.ai/variational.html#gpytorch.variational.OrthogonallyDecoupledVariationalStrategy) defines the variational strategy for the mean inducing points. It wraps an existing variational strategy/distribution that defines the covariance inducing points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_orthogonal_vs(model, train_x, num_mean_inducing=1000, num_covar_inducing=100):\n",
    "    \"\"\"Build an orthogonally decoupled variational strategy for ``model``.\n",
    "\n",
    "    Args:\n",
    "        model: the approximate GP that the inner ``VariationalStrategy`` is attached to.\n",
    "        train_x: tensor used only for its feature dimension (``size(-1)``), dtype and device;\n",
    "            its values are not read.\n",
    "        num_mean_inducing: number of randomly initialized inducing points for the mean\n",
    "            computations (default 1000).\n",
    "        num_covar_inducing: number of randomly initialized inducing points for the\n",
    "            covariance computations (default 100).\n",
    "\n",
    "    Returns:\n",
    "        An ``OrthogonallyDecoupledVariationalStrategy`` (mean inducing points with a\n",
    "        ``DeltaVariationalDistribution``) wrapping a ``VariationalStrategy`` (covariance\n",
    "        inducing points with a ``CholeskyVariationalDistribution`` and learned locations).\n",
    "    \"\"\"\n",
    "    num_features = train_x.size(-1)\n",
    "    # The mean points are drawn before the covariance points (keeps the random draws in a fixed order)\n",
    "    mean_inducing_points = torch.randn(\n",
    "        num_mean_inducing, num_features, dtype=train_x.dtype, device=train_x.device\n",
    "    )\n",
    "    covar_inducing_points = torch.randn(\n",
    "        num_covar_inducing, num_features, dtype=train_x.dtype, device=train_x.device\n",
    "    )\n",
    "\n",
    "    # Inner strategy: the (fewer) inducing points used for the covariance computations\n",
    "    covar_variational_strategy = gpytorch.variational.VariationalStrategy(\n",
    "        model, covar_inducing_points,\n",
    "        gpytorch.variational.CholeskyVariationalDistribution(covar_inducing_points.size(-2)),\n",
    "        learn_inducing_locations=True\n",
    "    )\n",
    "\n",
    "    # Outer strategy: the (more numerous) inducing points used for the mean computations\n",
    "    variational_strategy = gpytorch.variational.OrthogonallyDecoupledVariationalStrategy(\n",
    "        covar_variational_strategy, mean_inducing_points,\n",
    "        gpytorch.variational.DeltaVariationalDistribution(mean_inducing_points.size(-2)),\n",
    "    )\n",
    "    return variational_strategy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting it all together we have:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OrthDecoupledApproximateGP(gpytorch.models.ApproximateGP):\n",
    "    \"\"\"Approximate GP whose variational strategy is built by ``make_orthogonal_vs``.\n",
    "\n",
    "    The model has a constant mean and a scaled RBF kernel.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_points):\n",
    "        # `inducing_points` keeps the constructor signature of the other models in this notebook\n",
    "        # but is not used: make_orthogonal_vs draws its own inducing points, sized from the\n",
    "        # notebook-level `train_x`.\n",
    "        variational_strategy = make_orthogonal_vs(self, train_x)\n",
    "        super().__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return a MultivariateNormal built from the mean and kernel modules evaluated at ``x``.\"\"\"\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef883369bfc84b919c2999823052f04b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Training OrthDecoupledApproximateGP', max=10, style=ProgressS…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test OrthDecoupledApproximateGP MAE: 0.08162340521812439\n"
     ]
    }
   ],
   "source": [
    "train_and_test_approximate_gp(OrthDecoupledApproximateGP)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
